#!/usr/bin/env python3
import argparse, sys, math
import numpy as np, pandas as pd
from astropy.io import fits
from astropy.cosmology import FlatLambdaCDM

# Fixed flat Lambda-CDM cosmology (H0 = 70 km/s/Mpc, Omega_m = 0.3). It is kept
# deliberately simple because it only converts angular sizes to kpc for binning.
COSMO = FlatLambdaCDM(Om0=0.3, H0=70)

# Candidate structural-size column names, tried in this order by pick_col().
# The first name present in the structural table is used. main() treats
# whichever column is picked as an angular size in arcsec. This script applies
# no pixel-to-arcsec conversion.
SIZE_COL_CANDIDATES = [
    # common structural (effective-radius) names, expected in arcsec
    "RE_G", "RE_R", "RE_I", "RE_Z", "RE_arcsec", "RE_ARCSEC", "RE", "R_e",
    "REFF", "REFF_ARCSEC",
    # fallbacks: generic half-light-radius proxies.
    # NOTE(review): SExtractor-style FLUX_RADIUS is often reported in pixels
    # rather than arcsec; confirm the units of the catalogue in use before
    # relying on this fallback.
    "FLUX_RADIUS", "HALF_LIGHT_RADIUS", "HLR",
]

def _fits_column_to_native(values):
    """Return a detached 1-D copy of a FITS table column, or None if it is not 1-D.

    Byte-string columns are decoded to ``str`` with trailing blank padding
    removed. All other columns are converted to native byte order.
    """
    arr = np.asarray(values)
    if arr.ndim != 1:
        # Vector columns (e.g. TDIMn arrays) cannot be one DataFrame column.
        return None
    if arr.dtype.kind == "S":
        # FITS character columns are space-padded bytes; decode so IDs print
        # and compare as plain text instead of b'...'.
        return np.char.rstrip(np.char.decode(arr, "ascii", "replace"))
    # FITS stores numbers big-endian; astype() both fixes the byte order and
    # always copies, so the result no longer references the file's memory map.
    return arr.astype(arr.dtype.newbyteorder("="))


def read_fits_table(path):
    """Load the first table HDU of a FITS file into a pandas DataFrame.

    HDUs are scanned in order. The first one that exposes ``columns`` (a
    binary or ASCII table) is converted, one DataFrame column per 1-D FITS
    column. Conversion details:

    * the file is closed before returning, because data are copied out first;
    * numeric columns are converted to native byte order;
    * byte-string columns are decoded to ``str`` and right-stripped;
    * multi-dimensional (vector) columns are skipped.

    Raises:
        RuntimeError: if the file contains no table HDU.
    """
    with fits.open(path, memmap=True) as hdul:
        for hdu in hdul:
            if not hasattr(hdu, "columns"):
                continue
            names = list(hdu.columns.names)
            data = hdu.data
            if data is None:
                # Table HDU without a data section: keep the schema, no rows.
                return pd.DataFrame(columns=names)
            columns = {}
            for name in names:
                converted = _fits_column_to_native(data[name])
                if converted is not None:
                    columns[name] = converted
            return pd.DataFrame(columns)
    raise RuntimeError(f"No binary table HDU found in {path}")

def pick_col(cols, options):
    """Return the first entry of *options* that is contained in *cols*, else None.

    *cols* can be any container that supports ``in`` (a list, a set, a
    pandas Index, ...). *options* is scanned in order, so it acts as a
    priority list.
    """
    return next((candidate for candidate in options if candidate in cols), None)

def kpc_per_arcsec(z):
    """Return the proper transverse scale (kpc per arcsec) at redshift *z*.

    The conversion uses the module-level ``COSMO`` cosmology. *z* may be a
    scalar or array-like. A scalar input returns a scalar; an array input
    returns a float ndarray of the same shape. Redshifts that are
    non-finite, non-positive or missing (``None``) map to NaN.
    """
    z_arr = np.asarray(z, dtype=float)  # None -> NaN
    is_scalar = z_arr.ndim == 0
    z_arr = np.atleast_1d(z_arr)

    out = np.full(z_arr.shape, np.nan)
    good = np.isfinite(z_arr) & (z_arr > 0)
    if good.any():
        # One vectorized astropy call instead of one call per element;
        # the division by 60 converts kpc/arcmin to kpc/arcsec.
        out[good] = COSMO.kpc_proper_per_arcmin(z_arr[good]).value / 60.0
    return out[0] if is_scalar else out

def _edge_labels(edges):
    """Return one 'lo–hi' label per consecutive pair of *edges*."""
    return [f"{lo}–{hi}" for lo, hi in zip(edges[:-1], edges[1:])]


def assign_bins(df, rg_edges=(5.0, 7.5, 10.0, 12.5, 15.0),
                ms_edges=(10.2, 10.5, 10.8, 11.1)):
    """Add categorical ``R_G_bin`` and ``Mstar_bin`` columns to *df*.

    ``R_G_kpc`` is binned with *rg_edges* and ``Mstar_log10`` with
    *ms_edges*. Both use left-closed intervals ``[lo, hi)`` (``right=False``).
    Each bin is labelled ``"lo–hi"``. Values outside the edges, including a
    value equal to the last edge, get NaN.

    The default edges are the script's fixed binning. Pass other
    increasing sequences to rebin.

    *df* is modified in place and also returned.

    Raises:
        ValueError: from ``pd.cut`` if the edges are not increasing.
    """
    rg_edges = list(rg_edges)
    ms_edges = list(ms_edges)

    df["R_G_bin"] = pd.cut(df["R_G_kpc"], bins=rg_edges, right=False,
                           labels=_edge_labels(rg_edges))
    df["Mstar_bin"] = pd.cut(df["Mstar_log10"], bins=ms_edges, right=False,
                             labels=_edge_labels(ms_edges))
    return df

def _require_columns(df, columns, what):
    """Exit with "<what> missing '<col>'" for the first of *columns* absent from *df*."""
    for col in columns:
        if col not in df.columns:
            sys.exit(f"{what} missing '{col}'")


def _mstar_log10(mass, units="auto"):
    """Return ``(log10 stellar mass, units_used)`` for a stellar-mass column.

    *units* is one of:

    * ``"linear"``: values are masses in Msun, and log10 is taken;
    * ``"log10"``: values are already log10(M*/Msun) and are returned as-is;
    * ``"auto"``: the units are guessed from the data. If the median of the
      positive values is below 100, the column is treated as logarithmic,
      since linear galaxy masses in Msun are orders of magnitude larger.
      Otherwise it is treated as linear.

    Non-numeric and non-positive values become NaN. They are unphysical in
    either convention, and negative numbers such as -99 are a common
    "no fit" placeholder.
    """
    values = pd.to_numeric(mass, errors="coerce").astype(float)
    values = values.where(values > 0)
    if units == "auto":
        positive = values.dropna()
        units = "log10" if (not positive.empty and positive.median() < 100.0) else "linear"
    if units == "log10":
        return values, units
    return np.log10(values), units


def _id_text(value):
    """Render an ID for lens_id: decode byte strings and strip padding; leave others as-is."""
    if isinstance(value, bytes):
        return value.decode("ascii", "replace").strip()
    return value


def main():
    """Build the lens catalogue CSV from the KiDS DR4 bright, LePhare and structural tables.

    Steps:

    * join the three tables on ``ID`` (inner joins);
    * turn the structural angular size (arcsec) into a physical size
      ``R_G_kpc`` at the photo-z;
    * express the stellar mass as log10(M*/Msun);
    * drop rows with non-finite key quantities and bin the rest with
      assign_bins();
    * write the result to ``--out``, creating its directory if needed.

    The function exits with a message when a required column is missing.
    """
    from pathlib import Path  # only needed to create the output directory

    ap = argparse.ArgumentParser()
    ap.add_argument("--bright", required=True, help="KiDS_DR4_brightsample.fits")
    ap.add_argument("--lephare", required=True, help="KiDS_DR4_brightsample_LePhare.fits")
    ap.add_argument("--struct", required=True, help="KiDS DR4 structural fits (Sérsic sizes)")
    ap.add_argument("--out", default="data/lenses_true.csv")
    ap.add_argument("--max-rows", type=int, default=None)
    ap.add_argument("--mass-units", choices=("auto", "log10", "linear"), default="auto",
                    help="Units of LePhare MASS_MED: 'log10' (log10 Msun), 'linear' (Msun), "
                         "or 'auto' to guess from the data (default).")
    args = ap.parse_args()
    if args.max_rows is not None and args.max_rows < 0:
        ap.error("--max-rows must be >= 0")

    B = read_fits_table(args.bright)
    L = read_fits_table(args.lephare)
    S = read_fits_table(args.struct)

    # normalize column names
    for df in (B, L, S):
        df.columns = [str(c) for c in df.columns]

    # required in Bright
    _require_columns(B, ("ID", "RAJ2000", "DECJ2000"), "Bright catalog")
    if "zphot_ANNz2" not in B.columns:
        sys.exit("Bright catalog missing 'zphot_ANNz2' (photo-z)")

    # choose structural size column (arcsec)
    size_col = pick_col(S.columns, SIZE_COL_CANDIDATES)
    if size_col is None:
        sys.exit(f"Could not find a structural size column in {args.struct}. "
                 f"Tried: {SIZE_COL_CANDIDATES}")

    # join on ID (preferred)
    if "ID" not in S.columns or "ID" not in L.columns:
        sys.exit("Expected 'ID' in structural and LePhare tables for a clean join.")
    _require_columns(L, ("MASS_MED",), "LePhare catalog")

    # A repeated ID in any table multiplies rows in the inner joins below.
    for label, table in (("Bright", B), ("LePhare", L), ("structural", S)):
        n_dup = int(table["ID"].duplicated().sum())
        if n_dup:
            print(f"Warning: {n_dup} duplicated ID(s) in the {label} table; "
                  "the join will repeat those lenses.", file=sys.stderr)

    D = (B[["ID", "RAJ2000", "DECJ2000", "zphot_ANNz2"]]
         .merge(L[["ID", "MASS_MED"]], on="ID", how="inner")
         .merge(S[["ID", size_col]], on="ID", how="inner"))
    D = D.rename(columns={"RAJ2000": "ra_deg", "DECJ2000": "dec_deg",
                          "zphot_ANNz2": "z_lens", "MASS_MED": "Mstar_med",
                          size_col: "size_arcsec"})
    # `is not None` so that --max-rows 0 means zero rows, not "no limit".
    if args.max_rows is not None:
        D = D.iloc[:args.max_rows].copy()

    # Stellar mass -> log10. LePhare normally reports MASS_MED already as
    # log10(M*/Msun); taking log10 of that again would put every galaxy
    # outside the log-mass bins, so the units are detected (or given).
    # NOTE(review): KiDS-Bright LePhare masses may also need an
    # aperture-to-total flux correction; confirm against the release notes.
    mstar_log10, mass_units = _mstar_log10(D["Mstar_med"], args.mass_units)
    D["Mstar_log10"] = mstar_log10
    if args.mass_units == "auto":
        print(f"Interpreting MASS_MED as {mass_units} (auto-detected; "
              "override with --mass-units).", file=sys.stderr)

    # physical size R_G in kpc from the angular size at the photo-z
    D["kpc_per_arcsec"] = D["z_lens"].astype(float).map(kpc_per_arcsec)
    D["R_G_kpc"] = D["size_arcsec"].astype(float) * D["kpc_per_arcsec"].astype(float)

    # clean finite rows
    D = D.replace([np.inf, -np.inf], np.nan)
    D = D.dropna(subset=["ra_deg", "dec_deg", "z_lens", "Mstar_log10", "R_G_kpc"])

    # assign bins
    D = assign_bins(D)

    # lens id (byte-string IDs are decoded so they don't render as b'...')
    D["lens_id"] = D["ID"].map(lambda x: f"KiDSDR4_ID{_id_text(x)}")

    # order columns and write; make sure the output directory exists
    out = D[["lens_id", "ra_deg", "dec_deg", "z_lens", "R_G_kpc",
             "Mstar_log10", "R_G_bin", "Mstar_bin"]].copy()
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out.to_csv(out_path, index=False)
    print(f"Wrote {args.out} with {len(out)} rows.")
    # observed=False keeps empty bin combinations in the summary (and
    # avoids pandas' deprecation warning about the implicit default).
    print(out.groupby(["R_G_bin", "Mstar_bin"], observed=False).size())

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
